package top.lingkang.finaloauth2.server.store.impl;

import com.alibaba.fastjson.JSONObject;
import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;
import org.springframework.dao.DataAccessException;
import org.springframework.jdbc.core.BeanPropertyRowMapper;
import org.springframework.jdbc.core.JdbcTemplate;
import org.springframework.jdbc.core.RowMapper;
import top.lingkang.finaloauth2.entity.*;
import top.lingkang.finaloauth2.server.store.TokenStore;

import javax.sql.DataSource;
import java.sql.ResultSet;
import java.sql.SQLException;
import java.util.List;
import java.util.Timer;
import java.util.TimerTask;

/**
 * {@link TokenStore} implementation that persists tokens in a relational database through Spring's
 * {@link JdbcTemplate}.
 * <p>
 * Tables used: {@code fo_token}, {@code fo_user_details}, {@code fo_refresh_token} and
 * {@code fo_code}. Composite values (user details, the access token embedded in a refresh token,
 * the authorization-code scope) are stored as JSON text produced by fastjson; for the serializer
 * benchmark see https://blog.csdn.net/weixin_44480167/article/details/123783730
 * <p>
 * Each instance starts a daemon {@link Timer} that purges expired authorization codes and tokens
 * 10 seconds after construction and every 8 hours afterwards. The token cleanup uses a MySQL-style
 * multi-table {@code DELETE}.
 *
 * @author lingkang
 * Created by 2022/3/27
 */
public class JdbcTokenStore implements TokenStore {
    private static final Log log = LogFactory.getLog(JdbcTokenStore.class);

    /** Delay before the first cleanup run: 10 seconds. */
    private static final long CLEAN_DELAY_MILLIS = 10 * 1000L;
    /** Period between cleanup runs: 8 hours. */
    private static final long CLEAN_PERIOD_MILLIS = 8 * 60 * 60 * 1000L;

    private static final String SELECT_ACCESS_TOKEN =
            "select token,refresh_token,expiry,scope from fo_token where token=?";
    private static final String SELECT_REFRESH_TOKEN =
            "select refresh_token,access_token,user_details,expiry from fo_refresh_token where refresh_token=?";

    /** Maps a row of {@link #SELECT_ACCESS_TOKEN} (token, refresh_token, expiry, scope) to an access token. */
    private static final RowMapper<OAuth2AccessToken> ACCESS_TOKEN_MAPPER = new RowMapper<OAuth2AccessToken>() {
        @Override
        public OAuth2AccessToken mapRow(ResultSet resultSet, int rowNum) throws SQLException {
            OAuth2AccessToken accessToken = new OAuth2AccessToken();
            accessToken.setToken(resultSet.getString(1));
            accessToken.setRefreshToken(resultSet.getString(2));
            accessToken.setExpiry(resultSet.getLong(3));
            accessToken.setScope(resultSet.getString(4));
            return accessToken;
        }
    };

    /**
     * Maps a row of {@link #SELECT_REFRESH_TOKEN} to a refresh token; the embedded access token and
     * user details are parsed from their JSON columns.
     */
    private static final RowMapper<OAuth2RefreshToken> REFRESH_TOKEN_MAPPER = new RowMapper<OAuth2RefreshToken>() {
        @Override
        public OAuth2RefreshToken mapRow(ResultSet resultSet, int rowNum) throws SQLException {
            OAuth2RefreshToken refreshToken = new OAuth2RefreshToken();
            refreshToken.setRefreshToken(resultSet.getString(1));
            refreshToken.setAccessToken(JSONObject.parseObject(resultSet.getString(2), OAuth2AccessToken.class));
            refreshToken.setUserDetails(JSONObject.parseObject(resultSet.getString(3), UserDetails.class));
            refreshToken.setExpiry(resultSet.getLong(4));
            return refreshToken;
        }
    };

    /** Maps the first column of a row to a string. */
    private static final RowMapper<String> FIRST_COLUMN_MAPPER = new RowMapper<String>() {
        @Override
        public String mapRow(ResultSet resultSet, int rowNum) throws SQLException {
            return resultSet.getString(1);
        }
    };

    private final JdbcTemplate jdbcTemplate;

    /**
     * Creates the store and schedules the periodic cleanup of expired codes and tokens.
     *
     * @param dataSource data source holding the {@code fo_*} tables
     */
    public JdbcTokenStore(DataSource dataSource) {
        this.jdbcTemplate = new JdbcTemplate(dataSource);
        log.info("使用jdbc存储token，ok");
        // Daemon thread: the cleanup timer must not keep the JVM alive on shutdown.
        Timer cleaner = new Timer("JdbcTokenStore-cleaner", true);
        cleaner.schedule(new TimerTask() {
            @Override
            public void run() {
                // An exception escaping run() terminates the Timer thread and cancels every future
                // run, so each step is guarded independently to survive transient database errors.
                try {
                    cleanExpiryCode();
                } catch (RuntimeException e) {
                    log.error("Failed to purge expired authorization codes", e);
                }
                try {
                    cleanExpiryToken();
                } catch (RuntimeException e) {
                    log.error("Failed to purge expired tokens", e);
                }
            }
        }, CLEAN_DELAY_MILLIS, CLEAN_PERIOD_MILLIS);
    }

    /**
     * Looks up an access token.
     *
     * @param token access-token value
     * @return the token, or {@code null} when it does not exist, has expired (its row is then
     * deleted), or a database error occurred (logged)
     */
    @Override
    public OAuth2AccessToken getAccessToken(String token) {
        try {
            List<OAuth2AccessToken> query = jdbcTemplate.query(SELECT_ACCESS_TOKEN, ACCESS_TOKEN_MAPPER, token);
            if (!query.isEmpty()) {
                OAuth2AccessToken accessToken = query.get(0);
                // Remove on expiry.
                if (accessToken.getExpiry() < System.currentTimeMillis()) {
                    removeAccessToken(token);
                    return null;
                }
                return accessToken;
            }
        } catch (DataAccessException e) {
            log.error("Failed to load access token", e);
        }
        return null;
    }

    /**
     * Returns the access/refresh token pair recorded for a user and client, with their expiries filled in.
     * A recorded token that no longer resolves to a live token is reported as {@code null}.
     *
     * @param username user name
     * @param clientId OAuth2 client id
     * @return the recorded token info, or {@code null} when no record exists
     */
    @Override
    public UserTokenInfo getUserTokenInfo(String username, String clientId) {
        List<UserTokenInfo> query = jdbcTemplate.query(
                "select token,refresh_token as refreshToken from fo_user_details where username=? and client_id=?",
                new BeanPropertyRowMapper<>(UserTokenInfo.class),
                username, clientId
        );
        if (query.isEmpty()) {
            return null;
        }
        UserTokenInfo info = query.get(0);

        if (info.getToken() != null) {
            OAuth2AccessToken accessToken = getAccessToken(info.getToken());
            if (accessToken != null) {
                info.setTokenExpires(accessToken.getExpiry());
            } else {
                info.setToken(null);
            }
        }
        if (info.getRefreshToken() != null) {
            OAuth2RefreshToken refreshToken = getRefreshToken(info.getRefreshToken());
            if (refreshToken != null) {
                info.setRefreshExpires(refreshToken.getExpiry());
            } else {
                info.setRefreshToken(null);
            }
        }
        return info;
    }

    /**
     * Deletes an access token from {@code fo_token}.
     * <p>
     * {@code fo_user_details} is left untouched; {@link #getUserTokenInfo} reports the stale token
     * reference as {@code null}.
     *
     * @param token access-token value
     * @return the removed token, or {@code null} if it did not exist
     */
    @Override
    public OAuth2AccessToken removeAccessToken(String token) {
        List<OAuth2AccessToken> query = jdbcTemplate.query(SELECT_ACCESS_TOKEN, ACCESS_TOKEN_MAPPER, token);
        if (query.isEmpty()) {
            return null;
        }
        jdbcTemplate.update("delete from fo_token where token=?", token);
        return query.get(0);
    }

    /**
     * Stores an access token and records it as the current token of the user/client pair.
     *
     * @param accessToken token to store; an existing row with the same token value is replaced
     * @param userDetails owner of the token; serialized to JSON into {@code fo_user_details.details}
     */
    @Override
    public void storeAccessToken(OAuth2AccessToken accessToken, UserDetails userDetails) {
        // Delete first so the insert below cannot hit a duplicate token.
        jdbcTemplate.update("delete from fo_token where token=?", accessToken.getToken());
        jdbcTemplate.update("insert into fo_token(token,refresh_token,expiry,scope) values (?,?,?,?)",
                accessToken.getToken(),
                accessToken.getRefreshToken(),
                accessToken.getExpiry(),
                accessToken.getScope()
        );

        String details = JSONObject.toJSONString(userDetails);
        Integer existing = jdbcTemplate.queryForObject(
                "select count(*) from fo_user_details where username=? and client_id=?",
                Integer.class,
                userDetails.getUsername(), userDetails.getClientId()
        );
        if (existing == null || existing == 0) {
            jdbcTemplate.update("insert into fo_user_details(username,client_id,token,refresh_token,details) values (?,?,?,?,?)",
                    userDetails.getUsername(),
                    userDetails.getClientId(),
                    accessToken.getToken(),
                    accessToken.getRefreshToken(),
                    details
            );
        } else {
            // Keep refresh_token in step with the new access token, exactly as the insert branch does;
            // otherwise getUserTokenInfo() would keep reporting the previous refresh token.
            jdbcTemplate.update("update fo_user_details set token=?,refresh_token=?,details=? where username=? and client_id=?",
                    accessToken.getToken(),
                    accessToken.getRefreshToken(),
                    details,
                    userDetails.getUsername(),
                    userDetails.getClientId()
            );
        }
    }

    /**
     * Returns the user details attached to a live access token.
     *
     * @param token access-token value
     * @return the deserialized details, or {@code null} if the token is missing/expired, no details
     * row references it, or an error occurred (logged)
     */
    @Override
    public UserDetails getUserDetails(String token) {
        try {
            OAuth2AccessToken accessToken = getAccessToken(token);
            if (accessToken != null) {
                List<String> query = jdbcTemplate.query(
                        "select details from fo_user_details where token=?",
                        FIRST_COLUMN_MAPPER,
                        token);
                if (!query.isEmpty()) {
                    return JSONObject.parseObject(query.get(0), UserDetails.class);
                }
            }
        } catch (Exception e) {
            log.error("Failed to load user details", e);
        }
        return null;
    }

    /**
     * Looks up a refresh token.
     *
     * @param refreshToken refresh-token value
     * @return the token, or {@code null} when it does not exist, has expired (it is then removed
     * together with its access token), or a database error occurred (logged)
     */
    @Override
    public OAuth2RefreshToken getRefreshToken(String refreshToken) {
        try {
            List<OAuth2RefreshToken> query = jdbcTemplate.query(SELECT_REFRESH_TOKEN, REFRESH_TOKEN_MAPPER, refreshToken);
            if (!query.isEmpty()) {
                OAuth2RefreshToken auth2RefreshToken = query.get(0);
                if (auth2RefreshToken.getExpiry() < System.currentTimeMillis()) {
                    removeRefreshToken(refreshToken); // Remove on expiry.
                    return null;
                }
                return auth2RefreshToken;
            }
        } catch (DataAccessException e) {
            log.error("Failed to load refresh token", e);
        }
        return null;
    }

    /**
     * Deletes a refresh token and the access token it belongs to.
     *
     * @param refreshToken refresh-token value
     * @return the removed refresh token, or {@code null} if it did not exist
     */
    @Override
    public OAuth2RefreshToken removeRefreshToken(String refreshToken) {
        List<OAuth2RefreshToken> query = jdbcTemplate.query(SELECT_REFRESH_TOKEN, REFRESH_TOKEN_MAPPER, refreshToken);
        if (query.isEmpty()) {
            return null;
        }
        OAuth2RefreshToken auth2RefreshToken = query.get(0);
        // The embedded access token is parsed from a JSON column and may be null.
        OAuth2AccessToken accessToken = auth2RefreshToken.getAccessToken();
        if (accessToken != null) {
            removeAccessToken(accessToken.getToken());
        }
        jdbcTemplate.update("delete from fo_refresh_token where refresh_token=?", refreshToken);
        return auth2RefreshToken;
    }

    /**
     * Stores a refresh token, replacing any row with the same value. The embedded access token and
     * user details are serialized to JSON.
     *
     * @param refreshToken token to store
     */
    @Override
    public void storeRefreshToken(OAuth2RefreshToken refreshToken) {
        jdbcTemplate.update("delete from fo_refresh_token where refresh_token=?", refreshToken.getRefreshToken());
        jdbcTemplate.update("insert into fo_refresh_token(refresh_token,access_token,user_details,expiry) values (?,?,?,?)",
                refreshToken.getRefreshToken(),
                JSONObject.toJSONString(refreshToken.getAccessToken()),
                JSONObject.toJSONString(refreshToken.getUserDetails()),
                refreshToken.getExpiry()
        );
    }

    /**
     * Stores an authorization code.
     * <p>
     * NOTE(review): scope is written as JSON text here, but {@link #getOauth2Code} reads the raw
     * column back through {@link BeanPropertyRowMapper} without JSON parsing. Confirm that the round
     * trip matches the type of {@code Oauth2Code.scope}.
     *
     * @param oauth2Code code to store
     */
    @Override
    public void setCode(Oauth2Code oauth2Code) {
        jdbcTemplate.update("insert into fo_code(code,expiry,scope,client_id,state,redirect_url) values(?,?,?,?,?,?)",
                oauth2Code.getCode(),
                oauth2Code.getExpiry(),
                JSONObject.toJSONString(oauth2Code.getScope()),
                oauth2Code.getClientId(),
                oauth2Code.getState(),
                oauth2Code.getRedirectUri()
        );
    }

    /**
     * Redeems an authorization code. A code can be redeemed at most once: its row is deleted on
     * lookup whether or not it has expired.
     *
     * @param code authorization-code value
     * @return the code, including its stored state and redirect URI, or {@code null} when the code
     * does not exist, was already redeemed, has expired, or a database error occurred (logged)
     */
    @Override
    public Oauth2Code getOauth2Code(String code) {
        try {
            List<Oauth2Code> query = jdbcTemplate.query(
                    "select code,expiry,scope,client_id as clientId,state,redirect_url as redirectUri from fo_code where code=?",
                    new BeanPropertyRowMapper<>(Oauth2Code.class),
                    code
            );
            if (query.isEmpty()) {
                return null;
            }
            // Only the caller whose delete actually removed the row may redeem the code. This stops
            // two concurrent requests that both read the row from using the same code twice.
            int deleted = jdbcTemplate.update("delete from fo_code where code=?", code);
            if (deleted == 0) {
                return null;
            }
            Oauth2Code oauth2Code = query.get(0);
            if (oauth2Code.getExpiry() < System.currentTimeMillis()) {
                return null;
            }
            return oauth2Code;
        } catch (DataAccessException e) {
            log.error("Failed to load authorization code", e);
        }
        return null;
    }

    /**
     * Purges expired rows: first the {@code fo_user_details} rows whose current token has expired,
     * then expired access tokens and refresh tokens.
     * <p>
     * The user-details delete is a MySQL-style multi-table {@code DELETE}. Because of the
     * {@code t.expiry<?} condition, rows whose token no longer exists in {@code fo_token}
     * ({@code t.expiry} is NULL) are not matched.
     */
    private void cleanExpiryToken() {
        long current = System.currentTimeMillis();
        int u = jdbcTemplate.update(
                "delete u.* from fo_user_details u left join fo_token t on u.token=t.token" +
                        " where t.expiry<?",
                current
        );
        int t = jdbcTemplate.update("delete from fo_token where expiry<?", current);
        int r = jdbcTemplate.update("delete from fo_refresh_token where expiry<?", current);

        log.info("清除无效个数：user=" + u + "  token=" + t + "  refresh_token=" + r);
    }

    /** Purges expired authorization codes. */
    private void cleanExpiryCode() {
        int c = jdbcTemplate.update("delete from fo_code where expiry<?", System.currentTimeMillis());
        log.info("清除无效授权码个数：" + c);
    }
}
